{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
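    "# Networks with Parallel Concatenations (GoogLeNet)\n",
    "\n",
    "In the 2014 ImageNet image recognition challenge, a network architecture named GoogLeNet stood out [1]. Although its name pays homage to LeNet, little of LeNet is recognizable in its structure. GoogLeNet adopted NiN's idea of chaining networks within a network and improved on it substantially. Over the following years, researchers revised GoogLeNet several times; this section covers the first version of this model family."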
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
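    "The basic convolutional block in GoogLeNet is called the Inception block, named after the film *Inception*.\n",
    "An Inception block has four parallel paths. The first three use convolutional layers with window sizes of $1\\times 1$, $3\\times 3$, and $5\\times 5$ to extract information at different spatial scales; the middle two first apply a $1\\times 1$ convolution to the input to reduce the number of input channels and so lower model complexity. The fourth path uses a $3\\times 3$ max-pooling layer followed by a $1\\times 1$ convolutional layer that changes the number of channels. All four paths use appropriate padding so that the output has the same height and width as the input. Finally, the outputs of the four paths are concatenated along the channel dimension and passed to the next layer.\n",
    "\n",
    "The tunable hyperparameters of an Inception block are the numbers of output channels of each layer, which is how we control model complexity.\n",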
    "<div align=center>\n",
    "<img width=\"500\" src=\"../imgs/inception.png\"/>\n",
    "</div>\n",
    "<div align=center>Structure of the Inception block</div>"
   ]
  },
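  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of how the $1\\times 1$ convolution lowers model complexity, the count below compares two ways of producing 32 output channels with a $5\\times 5$ convolution from a 192-channel input. The channel numbers (192 in, 16 after the $1\\times 1$ layer, 32 out) are those of the third path of the first Inception block built later in this notebook; counting weights plus biases is our own bookkeeping and is meant only as a sketch.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{direct } 5\\times 5:\\quad & 5\\cdot 5\\cdot 192\\cdot 32 + 32 = 153{,}632 \\\\\n",
    "1\\times 1 \\text{ reduction}:\\quad & 1\\cdot 1\\cdot 192\\cdot 16 + 16 = 3{,}088 \\\\\n",
    "5\\times 5 \\text{ after reduction}:\\quad & 5\\cdot 5\\cdot 16\\cdot 32 + 32 = 12{,}832 \\\\\n",
    "\\text{reduced path in total}:\\quad & 3{,}088 + 12{,}832 = 15{,}920\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Under these assumptions, the path with the $1\\times 1$ reduction needs roughly a tenth of the parameters of the direct convolution."
   ]
  },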
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "\n",
    "torch.backends.cudnn.benchmark = True\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.enabled = True\n",
    "\n",
    "import dl_utils\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "class Inception(nn.Module):\n",
    "    # c1 - c4 are the output channel numbers of the layers on each path\n",
    "    def __init__(self, in_c, c1, c2, c3, c4):\n",
    "        super(Inception, self).__init__()\n",
    "        # Path 1: a single 1 x 1 convolutional layer\n",
    "        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)\n",
    "        # Path 2: a 1 x 1 convolutional layer followed by a 3 x 3 convolutional layer\n",
    "        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)\n",
    "        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)\n",
    "        # Path 3: a 1 x 1 convolutional layer followed by a 5 x 5 convolutional layer\n",
    "        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)\n",
    "        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)\n",
    "        # Path 4: a 3 x 3 max-pooling layer followed by a 1 x 1 convolutional layer\n",
    "        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
    "        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        p1 = F.relu(self.p1_1(x))\n",
    "        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n",
    "        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n",
    "        p4 = F.relu(self.p4_2(self.p4_1(x)))\n",
    "        # Concatenate the outputs along the channel dimension\n",
    "        return torch.cat((p1, p2, p3, p4), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
    "                   nn.ReLU(),\n",
    "                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 64, 24, 24])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(1, 1, 96, 96)\n",
    "x = b1(x)\n",
    "print(x.shape)"
   ]
  },
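  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed shape can be checked by hand. The sketch below assumes the usual output-size rule for convolutional and pooling layers with floor rounding (PyTorch's default, `ceil_mode=False`) and no dilation; the symbols $n$ (input size), $k$ (window size), $p$ (padding) and $s$ (stride) are our own notation, not the notebook's.\n",
    "\n",
    "$$\n",
    "n_{\\text{out}} = \\left\\lfloor \\frac{n + 2p - k}{s} \\right\\rfloor + 1\n",
    "$$\n",
    "\n",
    "Applied to the two layers of `b1` that change the spatial size, starting from a $96\\times 96$ input:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "7\\times 7 \\text{ convolution } (s=2,\\ p=3):\\quad & \\left\\lfloor \\frac{96 + 2\\cdot 3 - 7}{2} \\right\\rfloor + 1 = 47 + 1 = 48 \\\\\n",
    "3\\times 3 \\text{ max pooling } (s=2,\\ p=1):\\quad & \\left\\lfloor \\frac{48 + 2\\cdot 1 - 3}{2} \\right\\rfloor + 1 = 23 + 1 = 24\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "This agrees with the $24\\times 24$ feature maps reported above."
   ]
  },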
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unlike b1, no ReLU is applied between or after these two convolutions\n",
    "b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),\n",
    "                   nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
    "                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 192, 12, 12])\n"
     ]
    }
   ],
   "source": [
    "x = b2(x)\n",
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),\n",
    "                   Inception(256, 128, (128, 192), (32, 96), 64),\n",
    "                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 256, 12, 12])\n",
      "torch.Size([1, 480, 6, 6])\n"
     ]
    }
   ],
   "source": [
    "test_inception = Inception(192, 64, (96, 128), (16, 32), 32)\n",
    "test_x = test_inception(x)\n",
    "print(test_x.shape)\n",
    "\n",
    "x = b3(x)\n",
    "print(x.shape)"
   ]
  },
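  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The channel counts printed above follow from the concatenation along the channel dimension: the output of an Inception block has as many channels as the four paths produce together. As a small worked check, using the arguments passed to the two Inception blocks in `b3`:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{first block}:\\quad & 64 + 128 + 32 + 32 = 256 \\\\\n",
    "\\text{second block}:\\quad & 128 + 192 + 96 + 64 = 480\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "This is also why the second block is constructed with 256 input channels. The Inception blocks themselves leave height and width at 12; only the final stride-2 max-pooling layer reduces them to 6."
   ]
  },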
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 832, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),\n",
    "                   Inception(512, 160, (112, 224), (24, 64), 64),\n",
    "                   Inception(512, 128, (128, 256), (24, 64), 64),\n",
    "                   Inception(512, 112, (144, 288), (32, 64), 64),\n",
    "                   Inception(528, 256, (160, 320), (32, 128), 128),\n",
    "                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
    "x = b4(x)\n",
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1024])\n"
     ]
    }
   ],
   "source": [
    "b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),\n",
    "                   Inception(832, 384, (192, 384), (48, 128), 128),\n",
    "                   nn.MaxPool2d(3),\n",
    "                   dl_utils.FlattenLayer())\n",
    "temp_b5 = b5(x)\n",
    "print(temp_b5.shape)"
   ]
  },
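  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 1024 features printed above can be traced by hand; the following is only a sketch of that bookkeeping. The last Inception block in `b5` concatenates its four paths, and `nn.MaxPool2d(3)` uses a stride equal to its window size by default, with no padding, so the $3\\times 3$ feature maps coming out of `b4` collapse to $1\\times 1$:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{channels}:\\quad & 384 + 384 + 128 + 128 = 1024 \\\\\n",
    "\\text{height and width}:\\quad & \\left\\lfloor \\frac{3 - 3}{3} \\right\\rfloor + 1 = 1\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Flattening a $1024\\times 1\\times 1$ tensor therefore gives 1024 values per example."
   ]
  },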
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
      "    (1): ReLU()\n",
      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (1): Sequential(\n",
      "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    (1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (2): Sequential(\n",
      "    (0): Inception(\n",
      "      (p1_1): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (1): Inception(\n",
      "      (p1_1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (3): Sequential(\n",
      "    (0): Inception(\n",
      "      (p1_1): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (1): Inception(\n",
      "      (p1_1): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (2): Inception(\n",
      "      (p1_1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (3): Inception(\n",
      "      (p1_1): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(512, 144, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(144, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (4): Inception(\n",
      "      (p1_1): Conv2d(528, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(528, 160, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(528, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (4): Sequential(\n",
      "    (0): Inception(\n",
      "      (p1_1): Conv2d(832, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(832, 160, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(832, 32, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (1): Inception(\n",
      "      (p1_1): Conv2d(832, 384, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_1): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p2_2): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (p3_1): Conv2d(832, 48, kernel_size=(1, 1), stride=(1, 1))\n",
      "      (p3_2): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "      (p4_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
      "      (p4_2): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "    )\n",
      "    (2): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)\n",
      "    (3): FlattenLayer()\n",
      "  )\n",
      "  (5): Linear(in_features=1024, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "Inception_Net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))\n",
    "print(Inception_Net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output shape:  torch.Size([1, 64, 24, 24])\n",
      "output shape:  torch.Size([1, 192, 12, 12])\n",
      "output shape:  torch.Size([1, 480, 6, 6])\n",
      "output shape:  torch.Size([1, 832, 3, 3])\n",
      "output shape:  torch.Size([1, 1024])\n",
      "output shape:  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand(1, 1, 96, 96)\n",
    "for blk in Inception_Net.children(): \n",
    "    X = blk(X)\n",
    "    print('output shape: ', X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
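    "## Loading the Data and Training the Model\n",
    "\n",
    "We train the GoogLeNet model on images whose height and width are both 96 pixels. The training images again come from the Fashion-MNIST dataset."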
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 2.1584, train acc 0.162, test acc 0.673, time 72.6 sec\n",
      "epoch 2, loss 0.2728, train acc 0.802, test acc 0.813, time 66.7 sec\n",
      "epoch 3, loss 0.1280, train acc 0.860, test acc 0.853, time 63.3 sec\n",
      "epoch 4, loss 0.0829, train acc 0.879, test acc 0.875, time 61.9 sec\n",
      "epoch 5, loss 0.0590, train acc 0.893, test acc 0.888, time 62.6 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "# If an \"out of memory\" error occurs, reduce batch_size or resize\n",
    "train_iter, test_iter = dl_utils.load_data_fashion_mnist(batch_size, resize=96)\n",
    "\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(Inception_Net.parameters(), lr=lr)\n",
    "dl_utils.train_ch5(Inception_Net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
